#include "charmm_param.h"
#include "cell.h"
#include "bonded.h"
// #define min(x, y) ((x) < (y)?(x):(y))
// #define max(x, y) ((x) > (y)?(x):(y))
//#define DEBUG_FORCE
/* Debug trace for one bond (atoms itag--jtag): prints r0 and kr when
 * DEBUG_BONDED is defined, otherwise does nothing. */
void bonded_debug(long itag, long jtag, bond_param_t *p) {
#ifndef DEBUG_BONDED
  (void)itag;
  (void)jtag;
  (void)p;
#else
  printf("bond: %ld %ld %f %f\n", itag, jtag, p->r0, p->kr);
#endif
}

/* Debug trace for one angle (itag-jtag-ktag): prints theta0, ktheta and the
 * Urey-Bradley parameters r0ub/kub when DEBUG_BONDED is defined. */
void angle_debug(long itag, long jtag, long ktag, angle_param_t *p) {
#ifndef DEBUG_BONDED
  (void)itag;
  (void)jtag;
  (void)ktag;
  (void)p;
#else
  printf("angle: %ld %ld %ld %f %f %f %f\n",
         itag, jtag, ktag, p->theta0, p->ktheta, p->r0ub, p->kub);
#endif
}

/* Debug trace for one dihedral (itag-jtag-ktag-ltag): prints one line per
 * multiplicity term p[0..nmult-1] when DEBUG_BONDED is defined. */
void torsion_debug(long itag, long jtag, long ktag, long ltag, tori_param_t *p, int nmult) {
#ifndef DEBUG_BONDED
  (void)itag;
  (void)jtag;
  (void)ktag;
  (void)ltag;
  (void)p;
  (void)nmult;
#else
  assert(nmult > 0);
  const tori_param_t *term = p;
  const tori_param_t *const last = p + nmult;
  while (term != last) {
    printf("dihedral: %ld %ld %ld %ld %f %f %f\n",
           itag, jtag, ktag, ltag, term->phi0, term->kphi, term->n);
    ++term;
  }
#endif
}

/* Debug trace for one harmonic improper (itag-jtag-ktag-ltag): prints phi0
 * and kphi when DEBUG_BONDED is defined. */
void improper_debug(long itag, long jtag, long ktag, long ltag, harmonic_impr_param_t *p) {
#ifndef DEBUG_BONDED
  (void)itag;
  (void)jtag;
  (void)ktag;
  (void)ltag;
  (void)p;
#else
  printf("impr: %ld %ld %ld %ld %f %f\n",
         itag, jtag, ktag, ltag, p->phi0, p->kphi);
#endif
}
/*
 * Harmonic bond i--j:  E = kr * (|x_i - x_j| - r0)^2.
 * xi/xj are the coordinates, fi/fj are force accumulators (the force is
 * added, not assigned). Returns the bond energy; bonds flagged as shaked
 * contribute neither force nor energy.
 */
real bonded_force(vec<real> * fi, vec<real> * fj, vec<real> * xi, vec<real> * xj, bond_param_t *p) {
  if (p->shaked) {
    return 0;
  }

  vec<real> rij;
  vecsubv(rij, *xi, *xj);
  real len = vecnorm(rij);
  real stretch = len - p->r0;
  // -dE/dx_i = -2 * kr * stretch * rij / |rij|; atom j receives the opposite.
  real scale = -2.0 * p->kr * stretch / len;

#ifdef DEBUG_FORCE
  printf("%f %f %f\n", rij.x * scale, rij.y * scale, rij.z * scale);
#endif

  vecscaleaddv(*fi, *fi, rij, 1, scale);
  vecscaleaddv(*fj, *fj, rij, 1, -scale);
  return p->kr * stretch * stretch;
}

/*
 * CHARMM angle term for the triplet i-j-k (j is the vertex):
 *   E = ktheta * (theta - theta0)^2 + kub * (|x_k - x_i| - r0ub)^2
 * The second (Urey-Bradley) term acts directly between the 1-3 atoms i, k.
 * Forces are added to *fi, *fj, *fk; the total energy is returned.
 * Angles flagged as shaked contribute neither force nor energy.
 */
real angle_force(vec<real> * fi, vec<real> * fj, vec<real> * fk, vec<real> * xi, vec<real> * xj, vec<real> * xk, angle_param_t *p) {
  if (p->shaked) return 0;
  vec<real> dij, dkj;
  vecsubv(dij, *xi, *xj);
  vecsubv(dkj, *xk, *xj);

  real rinvij = 1. / vecnorm(dij);
  real rinvkj = 1. / vecnorm(dkj);

  // Clamp against rounding so acos() stays inside its domain.
  real cos_th = vecdot(dij, dkj) * rinvij * rinvkj;
  cos_th = max(min(cos_th, 1), -1);
  real sin_th = sqrt(1 - cos_th * cos_th);
  // For an exactly linear i-j-k, sin_th is 0 and coef_th below would be
  // +/-inf (NaN when d_th is also 0, e.g. theta0 = 180 deg), poisoning every
  // force it touches. The gradient direction built in f1/f3 vanishes
  // (mathematically) at collinearity, so flooring sin_th keeps the result
  // finite; the floor is far below any sin_th reached by a non-linear angle.
  const real sin_min = static_cast<real>(1e-6);
  if (sin_th < sin_min) sin_th = sin_min;
  real d_th = acos(cos_th) - p->theta0;
  // dE/dcos(theta) = -2 * ktheta * d_th / sin(theta).
  real coef_th = -2 * p->ktheta * d_th / sin_th;
  real e_th = p->ktheta * d_th * d_th;

  // f1 = -dE/dx_i and f3 = -dE/dx_k from the angle term; f2 (vertex) follows
  // from translational invariance: f1 + f2 + f3 = 0.
  vec<real> f1, f2, f3;
  vecscaleaddv(f1, dij, dkj, cos_th * rinvij, -rinvkj);
  vecscale(f1, f1, coef_th * rinvij);

  vecscaleaddv(f3, dkj, dij, cos_th * rinvkj, -rinvij);
  vecscale(f3, f3, coef_th * rinvkj);

  vecscaleaddv(f2, f1, f3, -1, -1);

  // Urey-Bradley 1-3 spring; its i and k forces cancel, so f2 is unchanged.
  vec<real> dik;
  vecsubv(dik, *xk, *xi);
  real rik = vecnorm(dik);
  real d_ub = rik - p->r0ub;
  real coef_ub = -2. * p->kub * d_ub / rik;
  real e_ub = p->kub * d_ub * d_ub;

  vecscaleaddv(f1, f1, dik, 1, -coef_ub);
  vecscaleaddv(f3, f3, dik, 1, coef_ub);
#ifdef DEBUG_FORCE
  printf("%f %f %f %f %f %f %f %f %f\n", f1.x, f1.y, f1.z, f2.x, f2.y, f2.z, f3.x, f3.y, f3.z);
#endif
  vecaddv(*fi, *fi, f1);
  vecaddv(*fj, *fj, f2);
  vecaddv(*fk, *fk, f3);
  return e_ub + e_th;
}
/*
 * CHARMM proper dihedral i-j-k-l, summed over nmult multiplicity terms p[0..nmult-1]:
 *   E = sum_i kphi_i * (1 + cos(n_i * phi - phi0_i))
 * Forces are added to *fi, *fj, *fk, *fl; the total energy is returned.
 * NOTE(review): unlike bonds/angles there is no `shaked` check here, and a
 * collinear i-j-k or j-k-l (zero-length n1 or n2) would divide by zero.
 */
real torsion_force(vec<real> * fi, vec<real> * fj, vec<real> * fk, vec<real> * fl, vec<real> * xi, vec<real> * xj, vec<real> * xk, vec<real> * xl, tori_param_t *p, int nmult) {
  vec<real> dij, djk, dkl;
  vecsubv(dij, *xi, *xj);
  vecsubv(djk, *xj, *xk);
  vecsubv(dkl, *xk, *xl);
  // n1, n2: normals of the (i,j,k) and (j,k,l) planes.
  // n3 = djk x n1 lies in the first plane, perpendicular to djk; projecting
  // n2 onto it gives the sine of the dihedral.
  vec<real> n1, n2, n3;
  veccross(n1, dij, djk);
  veccross(n2, djk, dkl);
  veccross(n3, djk, n1);
  // r1..r3 are reciprocal lengths (1/|n|), not lengths.
  real r1 = 1. / vecnorm(n1);
  real r2 = 1. / vecnorm(n2);
  real r3 = 1. / vecnorm(n3);

  vec<real> n1norm, n2norm, n3norm;
  vecscale(n1norm, n1, r1);
  vecscale(n2norm, n2, r2);
  vecscale(n3norm, n3, r3);
  real cos_phi = vecdot(n1norm, n2norm);
  real sin_phi = vecdot(n3norm, n2norm);

  // Note the negation: the sign convention of phi is set here.
  real phi = -atan2(sin_phi, cos_phi);
  // e accumulates the energy, coef accumulates dE/dphi.
  real e = 0, coef = 0;
  for (int i = 0; i < nmult; i++) {
    e += p[i].kphi * (1 + cos(p[i].n * phi - p[i].phi0));        // kphi * (1 + cos(n*phi - phi0))
    coef += -p[i].n * p[i].kphi * sin(p[i].n * phi - p[i].phi0); // d/dphi of the term above
  }
  vec<real> f1, f2, f2l, f2r, f3;
  // Chain rule through whichever of cos_phi / sin_phi is better conditioned:
  // dividing by sin_phi only when |sin_phi| > 0.1, otherwise by cos_phi
  // (then |cos_phi| >= ~0.995), so neither branch divides by a value near 0.
  if (fabs(sin_phi) > 0.1) {
    // Gradients of cos_phi with respect to n1 and n2.
    vec<real> dcosdn1, dcosdn2;
    vecscaleaddv(dcosdn1, n1norm, n2norm, cos_phi, -1);
    vecscale(dcosdn1, dcosdn1, r1);
    vecscaleaddv(dcosdn2, n2norm, n1norm, cos_phi, -1);
    vecscale(dcosdn2, dcosdn2, r2);
    coef = coef / sin_phi;
    veccross(f1, djk, dcosdn1);
    vecscale(f1, f1, coef);
    veccross(f3, dcosdn2, djk);
    vecscale(f3, f3, coef);
    veccross(f2l, dcosdn1, dij);
    veccross(f2r, dkl, dcosdn2);
    vecscaleaddv(f2, f2l, f2r, coef, coef);
  } else {
    // Gradients of sin_phi with respect to n3 and n2.
    vec<real> dsindn2, dsindn3;
    vecscaleaddv(dsindn3, n3norm, n2norm, sin_phi, -1);
    vecscale(dsindn3, dsindn3, r3);
    vecscaleaddv(dsindn2, n2norm, n3norm, sin_phi, -1);
    vecscale(dsindn2, dsindn2, r2);

    coef = -coef / cos_phi;
    veccross3(f1, djk, dsindn3, djk);
    vecscale(f1, f1, coef);

    veccross(f3, dsindn2, djk);
    vecscale(f3, f3, coef);

    veccross3(f2l, dij, dsindn3, djk);
    veccross3(f2r, djk, dsindn3, dij);
    vecscaleaddv(f2l, f2l, f2r, 1, -2);
    veccross(f2r, dkl, dsindn2);
    vecaddv(f2, f2l, f2r);
    vecscale(f2, f2, coef);
  }

#ifdef DEBUG_FORCE
  printf("%f %f %f %f %f %f %f %f %f\n", f1.x, f1.y, f1.z, f2.x, f2.y, f2.z, f3.x, f3.y, f3.z);
#endif
  //return 0;
  // Scatter: fi += f1, fj += f2 - f1, fk += f3 - f2, fl -= f3.
  // The four contributions sum to zero (no net force on the quadruplet).
  vecaddv(*fi, *fi, f1);
  vecaddv(*fj, *fj, f2);
  vecsubv(*fj, *fj, f1);
  vecaddv(*fk, *fk, f3);
  vecsubv(*fk, *fk, f2);
  vecsubv(*fl, *fl, f3);
  return e;
}

/*
 * Harmonic improper dihedral i-j-k-l:  E = kphi * (phi - phi0)^2,
 * where phi - phi0 is taken on the circle (wrapped into [-pi, pi]).
 * Forces are added to *fi, *fj, *fk, *fl; the energy is returned.
 * NOTE(review): a collinear i-j-k or j-k-l (zero-length n1 or n2) would
 * divide by zero.
 */
real improper_force(vec<real> * fi, vec<real> * fj, vec<real> * fk, vec<real> * fl, vec<real> * xi, vec<real> * xj, vec<real> * xk, vec<real> * xl, harmonic_impr_param_t *p) {
  vec<real> dij, djk, dkl;
  vecsubv(dij, *xi, *xj);
  vecsubv(djk, *xj, *xk);
  vecsubv(dkl, *xk, *xl);
  // n1, n2: plane normals; n3 = djk x n1 lies in the first plane and is used
  // to obtain the sine of the dihedral.
  vec<real> n1, n2, n3;
  veccross(n1, dij, djk);
  veccross(n2, djk, dkl);
  veccross(n3, djk, n1);
  // r1..r3 are reciprocal lengths (1/|n|).
  real r1 = 1. / vecnorm(n1);
  real r2 = 1. / vecnorm(n2);
  real r3 = 1. / vecnorm(n3);

  vec<real> n1norm, n2norm, n3norm;
  vecscale(n1norm, n1, r1);
  vecscale(n2norm, n2, r2);
  vecscale(n3norm, n3, r3);
  real cos_phi = vecdot(n1norm, n2norm);
  real sin_phi = vecdot(n3norm, n2norm);

  real phi = -atan2(sin_phi, cos_phi);

  // phi is in [-pi, pi], so for phi0 near +/-pi the raw difference can
  // exceed pi in magnitude (e.g. phi = -179 deg, phi0 = 180 deg gives -359
  // deg instead of 1 deg), producing a huge energy and a force that pushes
  // the long way round. Wrap the deviation onto [-pi, pi], as NAMD and
  // GROMACS do for harmonic impropers.
  const real pi = static_cast<real>(3.14159265358979323846);
  real dphi = phi - p->phi0;
  if (dphi > pi) dphi -= 2 * pi;
  else if (dphi < -pi) dphi += 2 * pi;

  real e = p->kphi * dphi * dphi;
  // coef = dE/dphi.
  real coef = 2 * p->kphi * dphi;
  vec<real> f1, f2, f2l, f2r, f3;
  // Chain rule through the better-conditioned of cos_phi / sin_phi so that
  // neither branch divides by a value near zero.
  if (fabs(sin_phi) > 0.1) {
    vec<real> dcosdn1, dcosdn2;
    vecscaleaddv(dcosdn1, n1norm, n2norm, cos_phi, -1);
    vecscale(dcosdn1, dcosdn1, r1);
    vecscaleaddv(dcosdn2, n2norm, n1norm, cos_phi, -1);
    vecscale(dcosdn2, dcosdn2, r2);
    coef = coef / sin_phi;
    veccross(f1, djk, dcosdn1);
    vecscale(f1, f1, coef);
    veccross(f3, dcosdn2, djk);
    vecscale(f3, f3, coef);
    veccross(f2l, dcosdn1, dij);
    veccross(f2r, dkl, dcosdn2);
    vecscaleaddv(f2, f2l, f2r, coef, coef);
  } else {
    vec<real> dsindn2, dsindn3;
    vecscaleaddv(dsindn3, n3norm, n2norm, sin_phi, -1);
    vecscale(dsindn3, dsindn3, r3);
    vecscaleaddv(dsindn2, n2norm, n3norm, sin_phi, -1);
    vecscale(dsindn2, dsindn2, r2);

    coef = -coef / cos_phi;
    veccross3(f1, djk, dsindn3, djk);
    vecscale(f1, f1, coef);

    veccross(f3, dsindn2, djk);
    vecscale(f3, f3, coef);

    veccross3(f2l, dij, dsindn3, djk);
    veccross3(f2r, djk, dsindn3, dij);
    vecscaleaddv(f2l, f2l, f2r, 1, -2);
    veccross(f2r, dkl, dsindn2);
    vecaddv(f2, f2l, f2r);
    vecscale(f2, f2, coef);
  }
#ifdef DEBUG_FORCE
  printf("%f %f %f %f %f %f %f %f %f\n", f1.x, f1.y, f1.z, f2.x, f2.y, f2.z, f3.x, f3.y, f3.z);
#endif
  // Scatter: fi += f1, fj += f2 - f1, fk += f3 - f2, fl -= f3 (sums to zero).
  vecaddv(*fi, *fi, f1);
  vecaddv(*fj, *fj, f2);
  vecsubv(*fj, *fj, f1);
  vecaddv(*fk, *fk, f3);
  vecsubv(*fk, *fk, f2);
  vecsubv(*fl, *fl, f3);
  return e;
}

#ifdef __sw__
// Declared only when the compiler defines __sw__; the definition is not in
// this file section. Presumably an architecture-specific bonded kernel for
// that target — TODO confirm.
void charmm_bonded_sw(cellgrid_t *grid, charmm_param_t *param, mdstat_t *stat);
#endif

#include "esmd_conf.h"
/*
 * Dispatch the nonbonded (LJ + Coulomb) force computation to the kernel
 * matching param->nonb.coultype: reaction field, MSM, or shifted Coulomb.
 * NOTE(review): any other coultype value computes no nonbonded forces at
 * all, silently — confirm that is intended.
 */
void charmm_nonbonded(cellgrid_t *grid, charmm_param_t *param, mdstat_t *stat) {
  if (param->nonb.coultype == COUL_RF) {
    nonbonded_force_lj_rf(grid, &param->nonb, stat);
  } else if (param->nonb.coultype == COUL_MSM) {
    nonbonded_force_lj_msm(grid, &param->nonb, stat);
  } else if (param->nonb.coultype == COUL_SHIFTED) {
    nonbonded_force_lj_shifted(grid, &param->nonb, stat);
  }
}
